## @package crf
# Module caffe2.python.crf
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, recurrent, model_helper, brew
import numpy as np

'''
Due to a limitation in RecurrentNetworkOp, this layer only supports batch_size=1.
In order to support batch_size > 1, we will have to implement the CRFUnit
and its gradient in C++ and handle the different batches there.
'''


class CRFWithLoss(object):
    """Linear-chain CRF loss layer built from caffe2 operators.

    Two synthetic tags are added after the ``num_classes`` real tags:
    BOS at index ``num_classes`` and EOS at index ``num_classes + 1``.
    Pairwise scores live in a ``(num_classes + 2) x (num_classes + 2)``
    transitions blob. Entry ``[i, j]`` scores a move from tag ``i`` to
    tag ``j``: row ids come from the previous label and column ids from the
    next label in ``_path_binary_scores``.

    Only batch_size=1 is supported (see the module-level note).
    """

    def __init__(self, model, num_classes, transitions_blob=None):
        """
        Args:
            model: model helper exposing ``param_init_net``, ``net`` and
                ``params`` (presumably a ``model_helper.ModelHelper``).
            num_classes: number of real tags, excluding BOS/EOS.
            transitions_blob: optional existing transitions blob. When it is
                None, a ``crf_transitions`` blob is created in the param init
                net, uniformly initialized in [-1, 1]. In both cases the blob
                is appended to ``model.params``.
        """
        self.model = model
        self.num_classes = num_classes
        self.num_classes_padded = num_classes + 2  # After adding BOS and EOS
        # Compare explicitly with None: the argument is a blob reference,
        # not a flag, so truthiness is not a meaningful test.
        if transitions_blob is None:
            transitions_blob = self.model.param_init_net.UniformFill(
                [],
                [core.ScopedBlobReference('crf_transitions')],
                shape=[self.num_classes_padded, self.num_classes_padded],
                min=-1.0,
                max=1.0
            )
        self.transitions = transitions_blob
        self.model.params.append(self.transitions)

    def crf_loss(self, predictions, labels, seq_lengths=None):
        """Build the CRF loss for a single sequence.

        Args:
            predictions: blob of per-step unary tag scores. It is presumably
                shaped (T, num_classes), since BOS/EOS columns are appended
                on axis 1 (TODO confirm against callers).
            labels: blob of gold tag indices, one per step (cast to int64).
            seq_lengths: accepted for API compatibility; the visible
                implementation does not use it.

        Returns:
            The ``crf_loss`` blob: log-sum-exp over all path scores minus
            the score of the labelled path.
        """
        # Since the transitions matrix is a shared parameter, need to
        # take a snapshot of it at the beginning since it can be updated
        # in between the operators that uses it when doing parallel updates
        transitions_snapshot = self.model.net.Copy(
            self.transitions, core.ScopedBlobReference('transitions_snapshot')
        )
        # Compute best path unary score from the logits
        path_unary_score = self._gather_entries_sum(
            predictions, labels, self.num_classes
        )
        # Append BOS and EOS entries to the predictions and labels
        predictions = self._pad_predictions(predictions)
        labels = self._pad_labels(labels)
        # Compute best path binary scores from the transitions matrix
        path_binary_score = self._path_binary_scores(
            labels, transitions_snapshot, seq_lengths
        )
        path_total_score = self.model.net.Add(
            [path_binary_score, path_unary_score],
            core.ScopedBlobReference('path_total')
        )
        # Compute all paths score
        zero_index = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=0
        )
        # The (padded) BOS row seeds the recurrence ...
        initial_state = self.model.net.Gather(
            [predictions, zero_index],
            core.ScopedBlobReference('rnn_initial'),
            dense_gradient=True
        )
        # ... and the remaining rows are the per-step inputs.
        input_data, _ = self.model.net.RemovePadding(
            [predictions],
            padding_width=1,
            end_padding_width=0,
            outputs=2,
        )
        input_data = self.model.net.ExpandDims(
            [input_data],
            core.ScopedBlobReference('rnn_input_data'),
            dims=[1]
        )
        # Due to a bug in RecurrentNetworkGradientOp, we need to copy the
        # transitions blob before sending it to the recurrent network
        transitions_copy = self.model.net.Copy(
            transitions_snapshot, core.ScopedBlobReference('transitions_copy')
        )
        all_paths_scores = self._crf_forward(
            input_data, initial_state, transitions_copy
        )
        loss = self.model.net.Sub(
            [all_paths_scores, path_total_score],
            core.ScopedBlobReference('crf_loss')
        )
        return loss

    def _pad_predictions(self, predictions):
        """Add BOS/EOS columns and rows to a predictions blob.

        Two columns (BOS, EOS) filled with ``low_score`` are appended to
        every step. Then a BOS row (0 at the BOS column, ``low_score``
        elsewhere) is prepended and an EOS row (0 at the EOS column) is
        appended. The result has two more rows and two more columns than
        the input.
        """
        low_score = -1000.0  # An arbitrary, very low score
        b_scores = np.array(
            [[low_score] * self.num_classes + [0, low_score]]
        ).astype(np.float32)

        e_scores = np.array(
            [[low_score] * self.num_classes + [low_score, 0]]
        ).astype(np.float32)

        b_scores = self.model.param_init_net.GivenTensorFill(
            [], "b_scores", shape=[1, self.num_classes_padded], values=b_scores
        )
        e_scores = self.model.param_init_net.GivenTensorFill(
            [], "e_scores", shape=[1, self.num_classes_padded], values=e_scores
        )

        # Build a (T, 1) column of low_score, where T = predictions.shape[0].
        zero_index = self.model.net.ConstantFill(
            [], shape=[1, ], value=0
        )
        length = self.model.net.Gather(
            [self.model.net.Shape([predictions]), zero_index],
        )
        length = self.model.net.Cast(length, to='int32')
        t_range = self.model.net.LengthsRangeFill(length)
        padding = self.model.net.ConstantFill([t_range], value=low_score)
        padding = self.model.net.ExpandDims(padding, dims=[1])
        padded_predictions, _ = self.model.net.Concat(
            [predictions, padding, padding],
            outputs=2,
            axis=1
        )
        padded_predictions_concat, _ = self.model.net.Concat(
            [b_scores, padded_predictions, e_scores],
            outputs=2,
            axis=0
        )
        return padded_predictions_concat

    def _pad_labels(self, labels):
        """Return int64 labels with the BOS index prepended and EOS appended."""
        bos_i = self.num_classes
        eos_i = self.num_classes + 1
        bos_i_b = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=bos_i
        )
        eos_i_b = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=eos_i
        )
        labels = self.model.net.Cast([labels], to='int64')
        padded_labels, _ = self.model.net.Concat(
            [bos_i_b, labels, eos_i_b],
            axis=0,
            outputs=2
        )
        return padded_labels

    def _path_binary_scores(self, labels, transitions, seq_lengths=None):
        """Sum ``transitions[labels[t], labels[t + 1]]`` over the padded path.

        ``seq_lengths`` is accepted but not used.
        """
        # Next labels (drop the first entry) index columns ...
        column_ids, _ = self.model.net.RemovePadding(
            [labels],
            outputs=2,
            padding_width=1,
            end_padding_width=0
        )
        # ... previous labels (drop the last entry) index rows.
        row_ids, _ = self.model.net.RemovePadding(
            [labels],
            outputs=2,
            padding_width=0,
            end_padding_width=1
        )
        # Since there is no multi-dimensional gather, I flatten the matrix to
        # a 1-d vector and transform the ids to (row_ids * num_columns +
        # column_ids) and do gather in 1-d
        num_columns_blob = self.model.net.ConstantFill(
            [row_ids],
            value=self.num_classes_padded,
        )
        flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
        flattened_ids = self.model.net.Add([flattened_ids, column_ids])
        flattened_transitions = self.model.net.FlattenToVec([transitions])
        entries = self.model.net.Gather(
            [flattened_transitions, flattened_ids],
            dense_gradient=True
        )
        # Go through model.net like every other op in this class.
        return self.model.net.ReduceFrontSum([entries])

    def _gather_entries_sum(self, in_data, indices, index_size):
        """Sum the entries of ``in_data`` selected by ``indices``.

        It one-hot encodes ``indices`` (width ``index_size``), multiplies
        the result elementwise with the flattened ``in_data`` and sums.
        It assumes ``in_data`` is (len(indices), index_size); TODO confirm.
        """
        indices = self.model.net.Cast([indices], to='int64')
        index_size_blob = self.model.param_init_net.ConstantFill(
            [],
            shape=[1],
            value=index_size,
        )
        query_one_hot = self.model.net.OneHot(
            [indices, index_size_blob]
        )
        flattened_query = self.model.net.FlattenToVec(query_one_hot)
        flattened_data = self.model.net.FlattenToVec(in_data)
        query_scores = self.model.net.DotProduct(
            [flattened_query, flattened_data]
        )
        final_sum = self.model.net.ReduceFrontSum([query_scores])
        return final_sum

    def _crf_forward(
        self,
        input_blob,
        initial_state,
        transitions_copy,
        seq_lengths=None
    ):
        """Return the log-sum-exp of the recurrence's last state as a scalar.

        ``seq_lengths`` is accepted but not used.
        """
        # Build the RNN net and get the last timestep output
        out_last = self.build_crf_net(
            input_blob, initial_state, transitions_copy
        )
        out_last, _ = self.model.net.Reshape(
            [out_last],
            outputs=2,
            shape=(self.num_classes_padded,)
        )
        zero_segment_id = self.model.param_init_net.ConstantFill(
            [],
            value=0,
            shape=[self.num_classes_padded],
            dtype=core.DataType.INT32,
        )

        # Compute the accumulated total score of all the paths
        accum_score = self.model.net.SortedSegmentRangeLogSumExp(
            [out_last, zero_segment_id]
        )
        accum_score, _ = self.model.net.Reshape(
            accum_score,
            outputs=2,
            shape=()
        )
        return accum_score

    def build_crf_net(self, input_blob, initial_state, transitions):
        '''
        Adds the crf_net recurrent operator to self.model.

        input_blob: the input sequence in a format T x N x D, where T is the
        sequence size, N the batch size and D the input dimension.
        ##Only supports batch-size 1##

        initial_state: blob used as the initial recurrent cell state.

        transitions: transitions blob, fed to the step net as an external
        input.

        Returns the last-timestep output of the recurrent network.
        '''
        scope = 'crf_net'

        def s(name):
            # We have to manually scope due to our internal/external blob
            # relationships.
            return "{}/{}".format(str(scope), str(name))

        step_model = model_helper.ModelHelper(name='crf_step',
                                              param_model=self.model)
        input_t, cell_t_prev, _ = (
            step_model.net.AddExternalInputs(
                core.ScopedBlobReference('input_t'),
                core.ScopedBlobReference('cell_t_prev'),
                transitions
            )
        )
        zero_segment_id = step_model.param_init_net.ConstantFill(
            [],
            [s('zero_segment_id')],
            value=0,
            shape=[self.num_classes_padded],
            dtype=core.DataType.INT32,
        )

        # A hack to bypass model cloning for test
        step_model.param_init_net.AddExternalOutput(zero_segment_id)

        # The CRF step: score every (previous tag i, current tag j) pair as
        # prev[i] + input[j] + transitions[i, j], then log-sum-exp over i.
        # Do tile
        prev_transpose = brew.transpose(
            step_model,
            cell_t_prev,
            [s('prev_transpose')],
            axes=(0, 2, 1),
        )
        prev_tiled = step_model.net.Tile(
            prev_transpose,
            [s('prev_tiled')],
            tiles=self.num_classes_padded,
            axis=2,
        )
        input_t_tiled = step_model.net.Tile(
            input_t,
            [s('input_t_tiled')],
            tiles=self.num_classes_padded,
            axis=1,
        )
        input_with_prev = step_model.net.Add(
            [prev_tiled, input_t_tiled],
            [s('input_with_prev')]
        )
        all_with_transitions = step_model.net.Add(
            [input_with_prev, transitions],
            [s('prev_with_transitions')],
            broadcast=1,
            use_grad_hack=1,
        )
        all_with_transitions_reshaped, _ = step_model.net.Reshape(
            all_with_transitions,
            [s('all_with_transitions_reshaped'), s('all_with_transitions_orig')],
            shape=(self.num_classes_padded, self.num_classes_padded)
        )
        # All-zero segment ids: one log-sum-exp over the leading axis.
        cell_t = step_model.net.SortedSegmentRangeLogSumExp(
            [all_with_transitions_reshaped, zero_segment_id],
            [s('cell_t')],
        )
        step_model.net.AddExternalOutputs(cell_t)

        # Recurrent network
        cell_input_blob = initial_state
        out_all, out_last = recurrent.recurrent_net(
            net=self.model.net,
            cell_net=step_model.net,
            inputs=[(input_t, input_blob)],
            initial_cell_inputs=[
                (cell_t_prev, cell_input_blob),
            ],
            links={
                cell_t_prev: cell_t,
            },
            scope=scope,
            outputs_with_grads=(1,)
        )
        return out_last

    @staticmethod
    def _viterbi_decode(predictions, transitions):
        """Return the highest-scoring tag path as a list of tag indices.

        predictions: (T, C) array of per-step tag scores, with T >= 1.
        transitions: (C, C) array; ``transitions[i, j]`` scores moving from
            tag ``i`` to tag ``j``.
        """
        predictions = np.asarray(predictions)
        transitions = np.asarray(transitions)
        trellis = np.zeros(predictions.shape)
        backpointers = np.zeros(predictions.shape, dtype=np.int32)
        trellis[0] = predictions[0]

        for t in range(1, predictions.shape[0]):
            # scores[i, j] = best score ending in i at t-1, then moving i -> j
            scores = np.expand_dims(trellis[t - 1], 1) + transitions
            trellis[t] = predictions[t] + np.max(scores, 0)
            backpointers[t] = np.argmax(scores, 0)

        path = [np.argmax(trellis[-1])]
        for bp in reversed(backpointers[1:]):
            path.append(bp[path[-1]])
        path.reverse()
        return path

    @staticmethod
    def _promote_path(predictions, path):
        """Return a float64 copy in which ``path[t]`` is each row's argmax.

        At every step the scores of the current best tag and ``path[t]``
        are swapped. ``predictions`` itself is left untouched.
        """
        promoted = np.array(predictions, dtype=np.float64, copy=True)
        for t, tag in enumerate(path):
            row = promoted[t]
            old_best = np.argmax(row)
            row[tag], row[old_best] = row[old_best], row[tag]
        return promoted

    def update_predictions(self, classes):
        """Re-rank per-step class scores so the Viterbi path wins.

        ``classes`` is padded like the predictions in ``crf_loss``. A Python
        op then runs Viterbi decoding against ``self.transitions``. At each
        step it swaps the scores of the current best tag and the tag on the
        Viterbi path, and strips the BOS/EOS rows and columns again.

        Returns:
            The ``post_crf_classes`` blob.
        """
        # Bind the pure helpers locally so the registered op does not keep
        # a reference to ``self``.
        viterbi_decode = self._viterbi_decode
        promote_path = self._promote_path

        def crf_update_predictions_op(inputs, outputs):
            predictions = inputs[0].data
            transitions = inputs[1].data
            path = viterbi_decode(predictions, transitions)
            # Work on a copy so the input blob is not mutated in place.
            new_predictions = promote_path(predictions, path)
            # Remove the BOS and EOS entries from the predictions matrix
            orig_predictions = new_predictions[1:-1, 0:-2]
            outputs[0].reshape(orig_predictions.shape)
            outputs[0].data[...] = orig_predictions

        padded_classes = self._pad_predictions(classes)
        new_classes = self.model.net.Python(crf_update_predictions_op)(
            [padded_classes, self.transitions],
            core.ScopedBlobReference('post_crf_classes')
        )
        return new_classes
